/*
Copyright 2025 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package asthelpergen

import (
	"fmt"
	"go/types"

	"github.com/dave/jennifer/jen"
)

// pathGen is the generator that produces ast_path.go. It gathers one step for
// every edge from a container type to a child node, then emits the ASTStep
// enum, its DebugString method and the GetNodeFromPath walker.
type pathGen struct {
	file      *jen.File // output file being assembled
	steps     []step    // steps gathered by the type callbacks, in visit order
	ifaceName string    // AST node interface name used in generated signatures
}

// step describes a single edge from a container type to a child.
type step struct {
	container types.Type // type holding the child (a struct, pointer-to-struct or slice type)
	typ       types.Type // type of the child: field type, slice type or slice element type
	name      string     // field name, or sliceMarker when the container itself is the slice
	slice     bool       // true when reaching the child requires an index offset
}

// sliceMarker is stored in step.name when the container type is itself a
// slice, so there is no field name to select; the padding spaces make it
// impossible to collide with a real Go field name.
const sliceMarker string = " -SLICE- "

// Compile-time assertion that *pathGen satisfies the generator interface.
var (
	_ generator = (*pathGen)(nil)
)

// newPathGen returns a pathGen that writes into a fresh file for package
// pkgname. The license header and the "generated code" marker are added up
// front; ifaceName is the node interface used in generated signatures.
func newPathGen(pkgname, ifaceName string) *pathGen {
	gen := &pathGen{
		ifaceName: ifaceName,
		file:      jen.NewFile(pkgname),
	}
	gen.file.HeaderComment(licenseFileHeader)
	gen.file.HeaderComment("Code generated by ASTHelperGen. DO NOT EDIT.")
	return gen
}

// AsEnum returns the identifier used for this step's ASTStep constant: the
// printable container type name, followed by the field name unless the
// container itself is the slice (name == sliceMarker).
func (s step) AsEnum() string {
	prefix := printableTypeName(s.container)
	if s.name != sliceMarker {
		return prefix + s.name
	}
	return prefix
}

// genFile assembles ast_path.go from the collected steps. It returns the
// output file name together with the populated jen file.
//
// Declarations are emitted in this order: the ASTStep type (a uint16), the
// const block enumerating every step, ASTStep.DebugString, and finally the
// GetNodeFromPath walker.
func (p *pathGen) genFile(spi generatorSPI) (string, *jen.File) {
	p.file.ImportName("fmt", "fmt")

	decls := []*jen.Statement{
		jen.Type().Id("ASTStep").Uint16(),
		p.buildConstWithEnum(),
		p.debugString(),
		p.generateGetNodeFromPath(spi),
	}
	for _, decl := range decls {
		p.file.Add(decl)
	}

	return "ast_path.go", p.file
}

// interfaceMethod is part of the generator contract. Interface types record
// no steps here; steps come only from the struct, pointer-to-struct and
// slice callbacks.
func (p *pathGen) interfaceMethod(_ types.Type, _ *types.Interface, _ generatorSPI) error {
	return nil
}

// structMethod records a step for each node-bearing field of the struct type
// t (see addStructFields). It never fails.
func (p *pathGen) structMethod(t types.Type, strct *types.Struct, spi generatorSPI) error {
	p.addStructFields(t, strct, spi)
	return nil
}

// ptrToStructMethod handles *T where T is a struct. Such containers are
// treated exactly like plain struct containers, with the pointer type t
// recorded as the container of each step.
func (p *pathGen) ptrToStructMethod(t types.Type, strct *types.Struct, spi generatorSPI) error {
	return p.structMethod(t, strct, spi)
}

// addStep appends a step to p.steps.
//
//   - container: the type holding the child
//   - typ:       the type of the child
//   - name:      the field name (or sliceMarker for a slice container)
//   - slice:     whether reaching the child needs an index offset
func (p *pathGen) addStep(container, typ types.Type, name string, slice bool) {
	p.steps = append(p.steps, step{
		container: container,
		typ:       typ,
		name:      name,
		slice:     slice,
	})
}

// addStructFields records a step for every field of strct (the struct behind
// container type t) that leads to another node:
//
//   - a field whose type implements the node interface becomes a plain
//     field step (slice == false);
//   - a field whose type is an unnamed slice ([]X) becomes an offset step
//     (slice == true) when either the slice type itself or its element type
//     implements the interface. The recorded step type is the slice in the
//     first case and the element type in the second.
//
// Fields matching neither rule are ignored.
func (p *pathGen) addStructFields(t types.Type, strct *types.Struct, spi generatorSPI) {
	iface := spi.iface()
	for i := 0; i < strct.NumFields(); i++ {
		field := strct.Field(i)
		fieldType := field.Type()

		if types.Implements(fieldType, iface) {
			p.addStep(t, fieldType, field.Name(), false)
			continue
		}

		// Named slice types are *types.Named, not *types.Slice, so only
		// literal []X field types reach the checks below.
		slice, isSlice := fieldType.(*types.Slice)
		if !isSlice {
			continue
		}
		switch {
		case types.Implements(slice, iface):
			p.addStep(t, slice, field.Name(), true)
		case types.Implements(slice.Elem(), iface):
			p.addStep(t, slice.Elem(), field.Name(), true)
		}
	}
}

// ptrToBasicMethod is part of the generator contract. Pointers to basic
// types never hold child nodes, so no step is recorded.
func (p *pathGen) ptrToBasicMethod(_ types.Type, _ *types.Basic, _ generatorSPI) error {
	return nil
}

// sliceMethod handles a slice container type t. When its element type
// implements the node interface, it records an offset step whose name is
// sliceMarker, because the container is indexed directly with no field
// selection.
func (p *pathGen) sliceMethod(t types.Type, slice *types.Slice, spi generatorSPI) error {
	if elem := slice.Elem(); types.Implements(elem, spi.iface()) {
		p.addStep(t, elem, sliceMarker, true)
	}
	return nil
}

// basicMethod is part of the generator contract. Basic types never hold
// child nodes, so no step is recorded.
func (p *pathGen) basicMethod(_ types.Type, _ *types.Basic, _ generatorSPI) error {
	return nil
}

// debugString builds the generated method
//
//	func (s ASTStep) DebugString() string
//
// It maps every ASTStep constant to a readable label:
//
//   - "(Container).Field" for field steps;
//   - "(Container)[]" when the container itself is the slice;
//   - the same labels with an "Offset" suffix for offset steps;
//   - visitableInner's own name for visitableInner.
//
// Unknown values make the generated method panic.
func (p *pathGen) debugString() *jen.Statement {
	cases := make([]jen.Code, 0, len(p.steps)+1)

	for _, st := range p.steps {
		enumName := st.AsEnum()
		containerName := types.TypeString(st.container, noQualifier)

		label := fmt.Sprintf("(%s).%s", containerName, st.name)
		if st.name == sliceMarker {
			label = fmt.Sprintf("(%s)[]", containerName)
		}

		// Offset steps carry the suffix in both the constant and its label.
		if st.slice {
			enumName += "Offset"
			label += "Offset"
		}

		cases = append(cases, jen.Case(jen.Id(enumName)).Block(jen.Return(jen.Lit(label))))
	}

	cases = append(cases, jen.Case(jen.Id(visitableInner)).Block(jen.Return(jen.Lit(visitableInner))))

	return jen.Func().
		Params(jen.Id("s").Id("ASTStep")).
		Id("DebugString").
		Params().
		String().
		Block(
			jen.Switch(jen.Id("s")).Block(cases...),
			jen.Panic(jen.Lit("unknown ASTStep")),
		)
}

var (
	// visitableInner names the extra ASTStep constant whose walk case unwraps
	// a Visitable by calling its VisitThis method.
	visitableInner = visitableName + "Inner"
)

// buildConstWithEnum builds the const block that enumerates every ASTStep.
// Each recorded step contributes one constant, using AsEnum's name with an
// "Offset" suffix for offset steps, in p.steps order. visitableInner comes
// last. The first constant is declared as `ASTStep = iota`; the rest inherit
// that type and increment implicitly.
func (p *pathGen) buildConstWithEnum() *jen.Statement {
	names := make([]string, 0, len(p.steps)+1)
	for _, st := range p.steps {
		name := st.AsEnum()
		if st.slice {
			name += "Offset"
		}
		names = append(names, name)
	}
	names = append(names, visitableInner)

	defs := make([]jen.Code, len(names))
	for i, name := range names {
		if i == 0 {
			defs[i] = jen.Id(name).Id("ASTStep").Op("=").Id("iota")
			continue
		}
		defs[i] = jen.Id(name)
	}

	return jen.Const().Defs(defs...)
}

// generateGetNodeFromPath builds the generated function
//
//	func GetNodeFromPath(node <Iface>, path ASTPath) <Iface>
//
// The generated loop runs while len(path) >= 2. Each iteration reads a step
// with path.nextPathStep(), drops the first two elements of path
// (presumably the encoded ASTStep — confirm against ASTPath) and dispatches
// on the step using the cases from generateWalkCases. When the loop ends,
// the current node is returned.
func (p *pathGen) generateGetNodeFromPath(spi generatorSPI) *jen.Statement {
	loopBody := []jen.Code{
		jen.Id("step").Op(":=").Id("path").Dot("nextPathStep").Call(),
		// path = path[2:]
		jen.Id("path").Op("=").Id("path").Index(jen.Lit(2), jen.Empty()),
		jen.Switch(jen.Id("step")).Block(p.generateWalkCases(spi)...),
	}

	return jen.Func().Id("GetNodeFromPath").
		Params(
			jen.Id("node").Id(p.ifaceName),
			jen.Id("path").Id("ASTPath"),
		).
		Id(p.ifaceName).
		Block(
			jen.For(jen.Len(jen.Id("path")).Op(">=").Lit(2)).Block(loopBody...),
			jen.Return(jen.Id("node")),
		)
}

// generateWalkCases builds the switch cases used inside GetNodeFromPath's
// loop. For every recorded step whose container type implements the node
// interface it emits one of:
//
//	case <Step>:
//		node = node.(<Container>).<Field>
//
//	case <Step>Offset:
//		idx, bytesRead := path.nextPathOffset()
//		path = path[bytesRead:]
//		node = node.(<Container>).<Field>[idx] // or node.(<Container>)[idx] for slice containers
//
// These are followed by a visitableInner case, which unwraps a Visitable via
// VisitThis(), and a default case that returns nil. Steps whose container
// does not implement the interface get no case here, so at runtime they fall
// into the default branch even though they have an enum constant.
func (p *pathGen) generateWalkCases(spi generatorSPI) []jen.Code {
	iface := spi.iface()
	var cases []jen.Code

	for _, st := range p.steps {
		// The generated code type-asserts node to the container, which only
		// makes sense when the container is itself a node.
		if !types.Implements(st.container, iface) {
			continue
		}

		enumName := st.AsEnum()
		// node.(<Container>), with the container rendered unqualified.
		asContainer := jen.Id("node").Assert(jen.Id(types.TypeString(st.container, noQualifier)))

		if !st.slice {
			// node = node.(<Container>).<Field>
			cases = append(cases, jen.Case(jen.Id(enumName)).Block(
				jen.Id("node").Op("=").Add(asContainer).Dot(st.name),
			))
			continue
		}

		// A slice container is indexed directly; otherwise select the field first.
		elems := asContainer
		if st.name != sliceMarker {
			elems = elems.Dot(st.name)
		}

		cases = append(cases, jen.Case(jen.Id(enumName+"Offset")).Block(
			// idx, bytesRead := path.nextPathOffset()
			jen.List(jen.Id("idx"), jen.Id("bytesRead")).Op(":=").Id("path").Dot("nextPathOffset").Call(),
			// path = path[bytesRead:]
			jen.Id("path").Op("=").Id("path").Index(jen.Id("bytesRead"), jen.Empty()),
			// node = <elems>[idx]
			jen.Id("node").Op("=").Add(elems.Index(jen.Id("idx"))),
		))
	}

	// case VisitableInner:
	//	node = node.(Visitable).VisitThis()
	cases = append(cases, jen.Case(jen.Id(visitableInner)).Block(
		jen.Id("node").Op("=").Id("node").Assert(jen.Id("Visitable")).Dot("VisitThis").Call(),
	))

	// Unknown steps terminate the walk with a nil node.
	cases = append(cases, jen.Default().Block(
		jen.Return(jen.Nil()),
	))
	return cases
}
